/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
 */
// Performs one lane-wise fused accumulation step on four float lanes.
//
//   kIsSquaredL2 == true  : returns accum + (a - b * mult)^2
//   kIsSquaredL2 == false : returns accum - a * b   (mult is not read)
//
// Both forms are built from the NEON fused multiply-add/subtract intrinsics.
template <bool kIsSquaredL2>
SCANN_SIMD_INLINE float32x4_t FusedMultiplyOpNeon(float32x4_t a, float32x4_t b, float32x4_t mult, float32x4_t accum)
{
    if constexpr (!kIsSquaredL2) {
        // vfmsq_f32(x, y, z) computes x - y * z, so this is accum - a * b.
        return vfmsq_f32(accum, a, b);
    } else {
        // residual = a - b * mult
        const float32x4_t residual = vfmsq_f32(a, b, mult);
        // accum + residual * residual
        return vfmaq_f32(accum, residual, residual);
    }
}

// Folds kNumDims consecutive dimensions, starting at index `dim`, into each of the kUnrollBy accumulators.
//
// For every j in [0, kUnrollBy) the int8 values at ptrs[j][dim .. dim + kNumDims) are sign-extended to int32,
// converted to float and combined with the matching query floats through FusedMultiplyOpNeon<kIsSquaredL2>:
//   kIsSquaredL2 == true  : accums[j] += (query - value * inv_multiplier)^2   (lane-wise)
//   kIsSquaredL2 == false : accums[j] -= query * value                        (lane-wise)
//
// Template parameters:
//   kNumDims      number of dimensions consumed by this call; must be 4, 8 or 16.
//   kIsSquaredL2  selects the accumulation formula above.
//   kUnrollBy     number of int8 rows processed per call.
//
// Parameters:
//   query                           float query; kNumDims floats are read from query + dim.
//   ptrs                            kUnrollBy int8 row pointers; exactly kNumDims bytes are read from
//                                   ptrs[j] + dim.
//   inv_multipliers_for_squared_l2  per-dimension multipliers; only read when kIsSquaredL2 is true.
//   dim                             index of the first dimension to process.
//   accums                          running accumulators, taken by value.
//
// Returns the updated accumulators.
//
// The overload only participates when Simd<float> is the same type as Sse4<float>.
template <size_t kNumDims, bool kIsSquaredL2, size_t kUnrollBy>
SCANN_SIMD_INLINE std::enable_if_t<std::is_same<Simd<float>, Sse4<float>>::value, Simd<float, kUnrollBy>> HandleXDims(
    const float *query, array<const int8_t *, kUnrollBy> ptrs, const float *inv_multipliers_for_squared_l2, size_t dim,
    Simd<float, kUnrollBy> accums)
{
    // NOTE(review): the condition is a compile-time constant, so a static_assert would reject a bad
    // instantiation at build time instead of only in debug runs - confirm no caller relies on the DCHECK.
    DCHECK(kNumDims == 4 || kNumDims == 8 || kNumDims == 16)
        << "Error: kNumDims must be 4, 8, or 16.";

    if constexpr (kNumDims == 16) {
        // mult is only loaded, and only read by FusedMultiplyOpNeon, when kIsSquaredL2 is true.
        float32x4x4_t mult;
        if constexpr (kIsSquaredL2) {
            mult = vld1q_f32_x4(inv_multipliers_for_squared_l2 + dim);
        }

        float32x4x4_t q = vld1q_f32_x4(query + dim);
        for (size_t j : Seq(kUnrollBy)) {
            // 16 int8 values -> two int16x8 halves -> four int32x4 quarters.
            int8x16_t data = vld1q_s8(ptrs[j] + dim);
            int16x8_t data_l = vmovl_s8(vget_low_s8(data));
            int16x8_t data_h = vmovl_high_s8(data);
            int32x4_t data0 = vmovl_s16(vget_low_s16(data_l));
            int32x4_t data1 = vmovl_high_s16(data_l);
            int32x4_t data2 = vmovl_s16(vget_low_s16(data_h));
            int32x4_t data3 = vmovl_high_s16(data_h);
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[0], vcvtq_f32_s32(data0), mult.val[0], (*accums[j]).vect_f32)};
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[1], vcvtq_f32_s32(data1), mult.val[1], (*accums[j]).vect_f32)};
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[2], vcvtq_f32_s32(data2), mult.val[2], (*accums[j]).vect_f32)};
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[3], vcvtq_f32_s32(data3), mult.val[3], (*accums[j]).vect_f32)};
        }
    }
    if constexpr (kNumDims == 8) {
        // mult is only loaded, and only read by FusedMultiplyOpNeon, when kIsSquaredL2 is true.
        float32x4x2_t mult;
        if constexpr (kIsSquaredL2) {
            mult = vld1q_f32_x2(inv_multipliers_for_squared_l2 + dim);
        }

        float32x4x2_t q = vld1q_f32_x2(query + dim);
        for (size_t j : Seq(kUnrollBy)) {
            // Use the 64-bit load so that only the 8 bytes consumed here are read; a 128-bit load would
            // touch 8 bytes beyond them and could run off the end of the row buffer.
            int8x8_t data = vld1_s8(ptrs[j] + dim);
            int16x8_t data_l = vmovl_s8(data);
            int32x4_t data0 = vmovl_s16(vget_low_s16(data_l));
            int32x4_t data1 = vmovl_high_s16(data_l);
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[0], vcvtq_f32_s32(data0), mult.val[0], (*accums[j]).vect_f32)};
            accums[j] = {
                FusedMultiplyOpNeon<kIsSquaredL2>(q.val[1], vcvtq_f32_s32(data1), mult.val[1], (*accums[j]).vect_f32)};
        }
    }
    if constexpr (kNumDims == 4) {
        // mult is only loaded, and only read by FusedMultiplyOpNeon, when kIsSquaredL2 is true.
        float32x4_t mult;
        if constexpr (kIsSquaredL2) {
            mult = vld1q_f32(inv_multipliers_for_squared_l2 + dim);
        }

        float32x4_t q = vld1q_f32(query + dim);
        for (size_t j : Seq(kUnrollBy)) {
            // Read exactly the 4 bytes consumed here (a 128-bit load would touch 12 bytes beyond them).
            // The bytes are packed with src[0] in the lowest byte so that it lands in lane 0 of the
            // vector built by vcreate_s8; the upper four lanes are zero and are never used.
            const int8_t *src = ptrs[j] + dim;
            const uint32_t packed = static_cast<uint32_t>(static_cast<uint8_t>(src[0])) |
                                    (static_cast<uint32_t>(static_cast<uint8_t>(src[1])) << 8) |
                                    (static_cast<uint32_t>(static_cast<uint8_t>(src[2])) << 16) |
                                    (static_cast<uint32_t>(static_cast<uint8_t>(src[3])) << 24);
            int8x8_t data = vcreate_s8(static_cast<uint64_t>(packed));
            int16x8_t data_l = vmovl_s8(data);
            int32x4_t data0 = vmovl_s16(vget_low_s16(data_l));
            accums[j] = {FusedMultiplyOpNeon<kIsSquaredL2>(q, vcvtq_f32_s32(data0), mult, (*accums[j]).vect_f32)};
        }
    }

    return accums;
}